import os
import csv
import pandas as pd

# Directory holding the raw IPUMS census extracts (one CSV per census year).
wd = '/disk/data3/census-ipums/v2019/csv'
os.chdir(wd)

# Directory holding this project's inputs (state table, city list) and outputs.
data_dir = '/disk/homedirs/nber/berkes/pandemic_data/'

# State table used to go from postal abbreviations (STUSPS) to FIPS codes (STATEFP).
state_info_df = pd.read_csv(os.path.join(data_dir, 'state_table.csv'))
state_info_df = state_info_df[['STUSPS', 'NAME', 'STATEFP']]

# List of cities for which we want to extract information; its 'State' column
# holds postal abbreviations that are matched against STUSPS.
cities_df = pd.read_csv(os.path.join(data_dir, 'cities_list.csv'))
cities_df['State'] = cities_df['State'].str.strip()
cities_df = cities_df.merge(state_info_df, how='left', left_on='State', right_on='STUSPS')

# A state abbreviation that fails to match leaves NaN in STATEFP. Besides
# silently dropping that city, the NaN upcasts the whole column to float, which
# would turn every FIPS code into e.g. '6.0' instead of '6' so that no city
# matches at all. Report and skip unmatched rows, and format the remaining codes
# as plain integers.
unmatched_states = cities_df['STATEFP'].isna()
if unmatched_states.any():
    print('WARNING: no FIPS code found for state(s):',
          sorted(cities_df.loc[unmatched_states, 'State'].astype(str).unique()))
matched_cities = cities_df[~unmatched_states]
state_fips = pd.to_numeric(matched_cities['STATEFP']).astype('int64').astype(str)

# Set of "city, statefip" keys to keep (city lower-cased and stripped, e.g.
# "boston, 25"); census rows are looked up in this set by the same format.
citystate_set = set(matched_cities['City'].str.strip().str.lower() + ', ' + state_fips)

# Census codes tallied per (city, decade) in the loop below:
#   lit     -- categories 0-4, plus possibly 9
#   school  -- two categories (1, 2)
#   sex     -- two categories (1, 2)
#   race    -- 100 = white
#   bpl     -- 453 = Germany
#   mtongue -- 2 = German, including German dialects (e.g. Swiss German)
# NOTE(review): bpl and mtongue are compared after dropping their last two
# digits, i.e. the CSV is expected to hold detailed codes (453xx, 02xx) -- confirm.

# Accumulator filled by the loop: vars_dict[city][decade] -> {counter name: count}
vars_dict = dict()

# Raw 'stdcity' spellings (lower-cased) that are aggregated into New York City.
_NYC_NAMES = frozenset({'new york', 'brooklyn', 'queens', 'manhattan',
                        'staten island', 'bronx', 'the bronx'})
# Code values counted for each categorical variable (compared as CSV strings).
_SEX_CODES = ('1', '2')
_SCHOOL_CODES = ('1', '2')
_LIT_CODES = ('0', '1', '2', '3', '4', '9')
# (exclusive upper bound, counter name) pairs; ages >= 70 go to 'age70plus'.
_AGE_BINS = ((20, 'age020'), (30, 'age2030'), (40, 'age3040'),
             (50, 'age4050'), (60, 'age5060'), (70, 'age6070'))
# Every counter kept per (city, decade), in output-column order.
_RECORD_KEYS = ('population', 'sex1', 'sex2', 'white', 'school1', 'school2',
                'lit0', 'lit1', 'lit2', 'lit3', 'lit4', 'lit9',
                'german', 'german_mtongue',
                'age020', 'age2030', 'age3040', 'age4050', 'age5060',
                'age6070', 'age70plus', 'sum_ages', 'age_denominator')


def _normalize_city(city):
    """Return the lower-cased, harmonised spelling of a census 'stdcity' value.

    - New York boroughs (and 'new york' itself) become 'new york city'.
    - 'pittsburg' becomes 'pittsburgh'.
    - Any name containing ' louis' becomes 'st louis'. This also catches
      variants such as 'east st. louis'; those only survive if 'st louis' is
      listed for their state.
    - Exactly 'st. paul' becomes 'st paul'.
    """
    city = city.lower()
    if city in _NYC_NAMES:
        return 'new york city'
    if city == 'pittsburg':
        return 'pittsburgh'
    if ' louis' in city:
        return 'st louis'
    # This test is stricter than the St. Louis one because the Census also
    # contains North/South/East/West St. Paul. Including them makes the
    # population quite different from the publicly available aggregated data,
    # which suggests they were separate entities at the start of the 20th century.
    if city == 'st. paul':
        return 'st paul'
    return city


def _new_record(has_mtongue):
    """Return a fresh counter dict; german_mtongue is -1 when mtongue is absent."""
    record = dict.fromkeys(_RECORD_KEYS, 0)
    if not has_mtongue:
        record['german_mtongue'] = -1
    return record


def _age_bin(age):
    """Return the name of the age-bracket counter that *age* falls into."""
    for upper, key in _AGE_BINS:
        if age < upper:
            return key
    return 'age70plus'


for decade in ['1900', '1910', '1920', '1930']:
    print(decade)
    # csv.reader requires the file to be opened with newline=''.
    with open(os.path.join(wd, decade + '.csv'), newline='') as census_file:
        # identify the columns of interest
        census_reader = csv.reader(census_file, delimiter=',')
        headers = next(census_reader)
        city_index = headers.index('stdcity')
        state_index = headers.index('statefip')
        sex_index = headers.index('sex')
        race_index = headers.index('race')
        school_index = headers.index('school')
        lit_index = headers.index('lit')
        birth_index = headers.index('bpl')
        age_index = headers.index('age')
        # variable not available in 1910
        mtongue_index = headers.index('mtongue') if 'mtongue' in headers else None
        has_mtongue = mtongue_index is not None

        for line in census_reader:
            raw_city = line[city_index]
            if raw_city == '':
                continue
            city = _normalize_city(raw_city)

            # keep only the (city, state) pairs listed in cities_list.csv
            if city + ', ' + line[state_index] not in citystate_set:
                continue

            # NOTE(review): records are keyed by city name only (the state is used
            # just for filtering above), so two listed cities sharing a name in
            # different states would be merged into one record.
            city_records = vars_dict.setdefault(city, {})
            record = city_records.get(decade)
            if record is None:
                # initialise all counters once per (city, decade) instead of per row
                record = city_records[decade] = _new_record(has_mtongue)

            record['population'] += 1

            sex = line[sex_index]
            if sex in _SEX_CODES:
                record['sex' + sex] += 1

            if line[race_index] == '100':
                record['white'] += 1

            school = line[school_index]
            if school in _SCHOOL_CODES:
                record['school' + school] += 1

            lit = line[lit_index]
            if lit in _LIT_CODES:
                record['lit' + lit] += 1

            # drop the last two digits of the (presumably detailed) birthplace code
            bpl = line[birth_index]
            if bpl != '' and int(bpl) // 100 == 453:
                record['german'] += 1

            if has_mtongue:
                mtongue = line[mtongue_index]
                if mtongue != '' and int(mtongue) // 100 == 2:
                    record['german_mtongue'] += 1

            age = line[age_index]
            if age != '':
                age = int(age)
                record[_age_bin(age)] += 1
                record['sum_ages'] += age
                record['age_denominator'] += 1


# Output columns, in file order: the (city, decade) keys followed by every
# counter accumulated in vars_dict[city][decade].
output_columns = ['city', 'decade', 'population', 'sex1', 'sex2', 'white',
                  'school1', 'school2', 'lit0', 'lit1', 'lit2', 'lit3', 'lit4',
                  'lit9', 'german', 'german_mtongue', 'age020', 'age2030',
                  'age3040', 'age4050', 'age5060', 'age6070', 'age70plus',
                  'sum_ages', 'age_denominator']

output_path = os.path.join(data_dir, 'city_infos_w_germans_non_imputed_ages.csv')
with open(output_path, 'w', newline='') as fout:
    # csv.writer quotes a city name containing a comma or quote instead of
    # corrupting the row; lineterminator='\n' keeps plain-newline row endings.
    writer = csv.writer(fout, lineterminator='\n')
    writer.writerow(output_columns)
    for city, records_by_decade in vars_dict.items():
        for decade, record in records_by_decade.items():
            writer.writerow([city, decade] + [record[col] for col in output_columns[2:]])
